import numpy as np
from math import fabs
from functions import updated_signal


def find_tip_spatial_dist(mat, ntips, hexagonal=True, return_raw=False):
    '''
    For every tip cell, collect the distances to nearby tips that are not
    shielded by a closer tip.

    mat: 2D array of scores; cells with score > 0 are treated as tips.
    ntips: unused, kept for backward compatibility.
    hexagonal: unused; distances always come from new_hex_dist.
    return_raw: unused, kept for backward compatibility.

    Returns a list with one entry per tip, tips ordered row-major (the order
    in which they appear when scanning mat). Each entry is a list of
    distances, in ascending order, drawn from the `thr` (= 6) closest other
    tips: the closest one is always kept, and each further candidate is kept
    only if the straight line through the reference tip and the candidate
    misses every disc of radius 0.5 centred on an already-kept tip.
    A lattice with a single tip yields [[]]; a lattice with no tips yields [].
    '''
    n, m = mat.shape
    thr = 6   # number of closest candidates examined per tip
    R = 0.5   # radius of the shielding disc around each kept tip

    # Row-major coordinates of the cells with score > 0, as floats.
    rows, cols = np.nonzero(np.asarray(mat) > 0.)
    x_tip = rows.astype(float)
    y_tip = cols.astype(float)

    select_final = []
    for k in range(x_tip.size):
        xi, yi = x_tip[k], y_tip[k]

        # All other tips, sorted by distance from tip k. Sorting one index
        # array keeps every distance paired with the coordinates of its own
        # tip; the stable sort keeps equidistant tips in row-major order.
        x_rem = np.delete(x_tip, k)
        y_rem = np.delete(y_tip, k)
        dist = np.array([new_hex_dist(xi, yi, xq, yq, n, m)
                         for xq, yq in zip(x_rem, y_rem)], dtype=float)
        order = np.argsort(dist, kind='stable')
        dist, x, y = dist[order], x_rem[order], y_rem[order]

        # Keep a candidate only if no already-kept tip shields it.
        # min() guards lattices with fewer than thr + 1 tips.
        select_x, select_y, select_tip = [], [], []
        for i in range(min(thr, dist.size)):
            xt, yt = x[i], y[i]
            shielded = any(_line_hits_disc(xi, yi, xt, yt, xc, yc, R)
                           for xc, yc in zip(select_x, select_y))
            if not shielded:
                select_x.append(xt)
                select_y.append(yt)
                select_tip.append(dist[i])
        select_final.append(select_tip)
    return select_final


def _line_hits_disc(xi, yi, xt, yt, xc, yc, R):
    '''
    Return True if the infinite line through (xi, yi) and (xt, yt) meets the
    circle of radius R centred at (xc, yc), i.e. if substituting the line into
    the circle equation gives a quadratic with a real root.

    NOTE(review): the test uses the whole line (not just the segment between
    the two tips) and raw lattice indices, without periodic wrapping or
    hexagonal row offsets — confirm this is the intended shielding geometry.
    '''
    if xi != xt:
        b = (yt - yi) / (xt - xi)   # slope of y = a + b*x
        a = yi - xi * b             # intercept
        p2 = 1 + b * b
        p1 = -2 * xc - 2 * b * (a - yc)
        p0 = xc * xc + (a - yc) * (a - yc) - R * R
    else:
        # Vertical line x = xi: solve the circle equation for y instead.
        p2 = 1
        p1 = -2 * yc
        p0 = yc * yc + (xi - xc) * (xi - xc) - R * R
    return p1 * p1 - 4 * p2 * p0 >= 0

def new_hex_dist(x1, y1, x2, y2, n, m):
    '''
    Lattice distance between cells (x1, y1) and (x2, y2) on an n x m grid
    with periodic boundaries along both axes.

    The pair is first unwrapped. If the separation along an axis is strictly
    larger than half the grid size, the smaller coordinate is shifted by one
    grid period. With dx = |x1 - x2| and dy = |y1 - y2| the result is:
      - dx + dy        if dx == 0 or dy == 0,
      - dx + dy - 1    if dy is even, or if the cell with the smaller x
                       has an even y,
      - dx + dy        otherwise.
    NOTE(review): the parity rule suggests a hexagonal lattice with offset
    rows, but it may differ from the usual offset-hex step count when dy is
    large relative to dx (e.g. dx=2, dy=6 gives 7) — confirm it is intended.

    Returns a float.
    '''
    # Periodic boundary conditions along x, then along y.
    if x2 - x1 > n / 2:
        x1 = x1 + n
    elif x1 - x2 > n / 2:
        x2 = x2 + n
    if y2 - y1 > m / 2:
        y1 = y1 + m
    elif y1 - y2 > m / 2:
        y2 = y2 + m

    dx = fabs(x1 - x2)
    dy = fabs(y1 - y2)
    if dx == 0 or dy == 0:
        return dx + dy

    # Parity of the y coordinate belonging to the point with the smaller x.
    y_low = y2 if x1 > x2 else y1
    if dy % 2 == 0 or y_low % 2 == 0:
        return dx + dy - 1
    return dx + dy


def stoc_select(mat, p=0.246):
    '''
    Stochastically select sprouts among tip cells.

    Each cell with score > 0 independently becomes a sprout (1) with
    probability p; every other cell is set to -1.

    mat: 2D score matrix; it is not modified.
    p: selection probability per tip cell (default 0.246).

    Returns a new array with mat's shape and dtype holding 1 / -1.
    '''
    # Work on a copy so the caller's score matrix is left untouched.
    score_spr = np.array(mat, copy=True)
    n, m = score_spr.shape
    for i in range(n):
        for j in range(m):
            # A random number is drawn only for tip cells, in row-major
            # order, so a seeded RNG consumes the same stream per tip.
            if score_spr[i][j] > 0. and np.random.uniform(low=0, high=1) < p:
                score_spr[i][j] = 1
            else:
                score_spr[i][j] = -1
    return score_spr

def compute_neigh_tips(mat, hexagonal=True):
    '''
    Build the sprout matrix of the 6-stalk model: a tip becomes a sprout only
    if it has no neighbouring tip.

    mat: 2D score matrix; cells with score > 0 are tips (any n x m shape).
    hexagonal: forwarded to updated_signal.

    Returns a float array with mat's shape holding 1 for sprouts and -1
    everywhere else.
    '''
    # tip = 1, stalk = 0
    tip_mat = (np.asarray(mat) > 0.).astype('int')
    # Number of neighbouring tips. Assumes updated_signal returns an array of
    # tip_mat's shape holding the fraction of tip neighbours (hence the
    # factor 6) — TODO confirm against functions.updated_signal.
    tip_neigh = np.asarray(6*updated_signal(tip_mat, hexagonal=hexagonal), dtype='int')
    is_sprout = (tip_mat == 1) & (tip_neigh == 0)
    return np.where(is_sprout, 1., -1.)


def HSR_sprout_pattern(mat, n_sprout):
    '''
    Construct a sprout pattern in the hard-sphere repulsion model.

    Tips (cells with score > 0) are promoted to sprouts one at a time. Each
    new sprout is chosen uniformly at random among the tips that currently
    have no sprouting neighbour, until n_sprout sprouts exist.

    mat: 2D score matrix (any n x m shape).
    n_sprout: number of sprouts to place.

    Returns a float array with mat's shape: 1 for sprouts, 0 elsewhere.
    Raises ValueError if no eligible tip remains before n_sprout sprouts
    have been placed.
    '''
    # Cells with score <= 0 are stalks; everything else is a tip.
    tip = np.where(np.asarray(mat) <= 0., 0., 1.)
    sprout = np.zeros(tip.shape)

    while np.sum(sprout) < n_sprout:
        # Neighbouring-sprout signal. Assumes updated_signal returns an array
        # of sprout's shape (fraction of neighbours times 6) — TODO confirm.
        neigh_spr = 6 * updated_signal(sprout)
        candidates = np.argwhere((tip == 1) & (neigh_spr == 0))
        if len(candidates) == 0:
            raise ValueError(
                f"cannot place {n_sprout} sprouts: only {int(np.sum(sprout))} "
                "fit under hard-sphere repulsion"
            )
        # Uniform choice among eligible tips.
        i, j = candidates[np.random.randint(len(candidates))]
        tip[i, j] = 0
        sprout[i, j] = 1
    return sprout

